{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from itertools import product\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision.transforms as tv\n",
    "\n",
    "from data import SimCLRAugment\n",
    "from models import Encoder, Projector\n",
    "\n",
    "%config InlineBackend.figure_format = 'retina'\n",
    "plt.rcParams['figure.figsize'] = (20,20)\n",
    "sns.set()\n",
    "np.random.seed(2020)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "filename = f'./checkpoints/{md5}.pkl'\n",
    "checkpoint = torch.load(filename)\n",
    "\n",
    "model = Encoder(checkpoint['hparams'])\n",
    "model.load_state_dict(checkpoint['encoder_state_dict'])\n",
    "model.cuda()\n",
    "\n",
    "proj = Projector(checkpoint['hparams'])\n",
    "proj.load_state_dict(checkpoint['projector_state_dict'])\n",
    "proj.cuda()\n",
    "\n",
    "checkpoint['hparams'].as_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision import datasets\n",
    "\n",
    "batchsize = 512\n",
    "train_transform = SimCLRAugment(checkpoint['hparams'], batchsize)\n",
    "\n",
    "cifar_train_data = datasets.CIFAR10(os.getenv('SLURM_TMPDIR')+'/data', train=True, transform=train_transform, download=False)\n",
    "cifar_test_data = datasets.CIFAR10(os.getenv('SLURM_TMPDIR')+'/data', train=False, transform=train_transform, download=False)\n",
    "\n",
    "svhn_train_data = datasets.SVHN(os.getenv('SLURM_TMPDIR')+\"/data/svhn\", split='train', transform=train_transform, download=False)\n",
    "fake_data = datasets.FakeData(4000, (3, 32, 32), transform=train_transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def get_cos_dis(x, use_all=False, use_proj=False):\n",
    "  x = x.cuda()\n",
    "  z = model(x)\n",
    "  if use_proj:\n",
    "    z = proj(z)\n",
    "\n",
    "  znorm = z / torch.norm(z, 2, dim=1, keepdim=True)\n",
    "  cos_sim = torch.einsum('id,jd->ij', znorm, znorm) / 0.5\n",
    "\n",
    "  if use_all:\n",
    "    indices = torch.triu_indices(batchsize,batchsize, 1)\n",
    "    dist = cos_sim[indices[0], indices[1]]\n",
    "    var, mean = torch.var_mean(cos_sim[indices[0], indices[1]])\n",
    "  else:\n",
    "    dist = cos_sim[0,:]\n",
    "    var, mean = torch.var_mean(cos_sim[0, :])\n",
    "\n",
    "  return dist, mean, var\n",
    "\n",
    "\n",
    "def plot_density(dataset, name, count, axes, color_idx):\n",
    "  for idx, param in enumerate(product((False, True), (False, True))):\n",
    "    xs = []\n",
    "    means = []\n",
    "    variances = []\n",
    "    indices = np.random.randint(0, len(dataset), count)\n",
    "    for i in indices:\n",
    "      dist, mean, var = get_cos_dis(dataset[i][0], *param)\n",
    "      xs.append(dist.cpu().numpy())\n",
    "      means.append(mean.cpu().item())\n",
    "      variances.append(var.cpu().item())\n",
    "\n",
    "    axes[idx,0].set_title('Distributions')\n",
    "    axes[idx,0].hist(xs, histtype=u'bar', alpha=0.5, bins=100, stacked=True, density=True, color=[sns.color_palette()[color_idx] for _ in range(len(xs))], lw=0, label=name)\n",
    "    axes[idx,0].set_xlim(0, 2)\n",
    "    axes[idx,0].legend()\n",
    "\n",
    "    axes[idx,1].set_title('Means')\n",
    "    sns.distplot(means, label=name, ax=axes[idx,1])\n",
    "    axes[idx,1].legend()\n",
    "\n",
    "    axes[idx,2].set_title('Variances')\n",
    "    sns.distplot(variances, label=name, ax=axes[idx,2])\n",
    "    axes[idx,2].legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_, axes = plt.subplots(4, 3)\n",
    "plot_density(cifar_train_data, 'CIFAR-10 Train', 20, axes, 0)\n",
    "plot_density(cifar_train_data, 'CIFAR-10 Train', 20, axes, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_, axes = plt.subplots(4, 3)\n",
    "plot_density(cifar_train_data, 'CIFAR-10 Train', 20, axes, 0)\n",
    "plot_density(cifar_test_data, 'CIFAR-10 Test', 20, axes, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_, axes = plt.subplots(4, 3)\n",
    "plot_density(cifar_train_data, 'CIFAR-10 Train', 20, axes, 0)\n",
    "plot_density(fake_data, 'Fake Data', 20, axes, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_, axes = plt.subplots(4, 3)\n",
    "plot_density(cifar_train_data, 'CIFAR-10 Train', 20, axes, 0)\n",
    "plot_density(svhn_train_data, 'SVHN Train', 20, axes, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (falr)",
   "language": "python",
   "name": "falr"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}